import argparse
import os
import random
import numpy as np
import torch
import ast
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments
)
from torch import nn
from torch.utils.data import DataLoader, Dataset

def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator; when CUDA is available, the generators of all CUDA devices
    are seeded as well.

    Args:
        seed: Value handed to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Nothing more to do on CPU-only machines.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Synthetic Data Generation Functions
def generate_synthetic_example(seq_length=10, K=5, noise_range=100, idx=None):
    """Build one synthetic example from an arithmetic progression modulo ``K``.

    Token ``t`` has the form ``S{signal}_N{noise}`` where
    ``signal = (s0 + t * d) % K`` for a random start ``s0`` in ``[0, K - 1]``
    and a random non-zero step ``d`` in ``[1, K - 1]``, and ``noise`` is drawn
    uniformly from ``[0, noise_range - 1]``.

    Args:
        seq_length: Total number of tokens generated (the last one is held out).
        K: Modulus of the progression.
        noise_range: Exclusive upper bound for the noise integer.
        idx: Optional identifier; stored as a string under ``"id"`` when given.

    Returns:
        A dict with ``"context"`` (all tokens but the last, space-joined),
        ``"target"`` (the signal of the last token, as a string) and, if
        ``idx`` is not ``None``, ``"id"``.
    """
    # Draw order matters for reproducibility: start, step, then one noise
    # value per position.
    start = random.randint(0, K - 1)
    step = random.randint(1, K - 1)  # never zero, so the signal always moves

    residues = [(start + position * step) % K for position in range(seq_length)]
    pieces = [
        f"S{residue}_N{random.randint(0, noise_range - 1)}"
        for residue in residues
    ]

    record = {
        # Everything except the final token is shown to the model ...
        "context": " ".join(pieces[:-1]),
        # ... and the final token's signal is what it must predict.
        "target": str(residues[-1]),
    }
    if idx is not None:
        record["id"] = str(idx)
    return record

def generate_synthetic_dataset(num_examples, seq_length=10, K=5, noise_range=100):
    """Generate ``num_examples`` synthetic examples with ids ``0 .. num_examples - 1``.

    Args:
        num_examples: How many examples to create.
        seq_length: Forwarded to ``generate_synthetic_example``.
        K: Forwarded to ``generate_synthetic_example``.
        noise_range: Forwarded to ``generate_synthetic_example``.

    Returns:
        A list of example dicts, in generation order.
    """
    dataset = []
    for example_idx in range(num_examples):
        dataset.append(
            generate_synthetic_example(seq_length, K, noise_range, idx=example_idx)
        )
    return dataset

class SyntheticDataset(Dataset):
    """Map-style dataset that serves pre-generated examples from memory."""

    def __init__(self, examples):
        # Kept by reference; items are returned exactly as stored.
        self.examples = examples

    def __len__(self):
        """Return the number of stored examples."""
        count = len(self.examples)
        return count

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        example = self.examples[idx]
        return example

def collate_fn(batch, tokenizer):
    """Collate synthetic examples into a padded causal-LM batch.

    For each example the prompt (``context``) and ``context + " " + target``
    are tokenized separately. The labels copy the combined token ids but mask
    every prompt position with -100, so only the target token(s) appended at
    the end contribute to the loss.

    Args:
        batch: Sequence of dicts with ``"context"`` and ``"target"`` strings
            and an optional ``"id"``. When ``"id"`` is missing, the example's
            position in the batch (as a string) is used instead.
        tokenizer: Callable tokenizer that returns a mapping with
            ``"input_ids"`` and exposes ``pad(...)``. Its ``padding_side``
            attribute (default ``"right"`` if absent) decides on which side
            the labels are padded.

    Returns:
        The object returned by ``tokenizer.pad`` (padded ``input_ids`` plus
        whatever else ``pad`` adds; an attention mask is requested), extended
        with ``"labels"`` (long tensor, -100 on prompt and padding
        positions), ``"ids"``, ``"raw_context"`` and ``"raw_target"`` (plain
        Python lists).
    """
    inputs = []
    labels = []
    ids = []
    raw_contexts = []
    raw_targets = []

    for i, item in enumerate(batch):
        context = item["context"]
        target = item["target"]
        ids.append(item.get("id", str(i)))
        raw_contexts.append(context)
        raw_targets.append(target)

        prompt = context  # prompt is just the context
        combined = prompt + " " + target  # combine prompt and target

        input_ids = tokenizer(
            combined,
            add_special_tokens=True,
            return_tensors=None,
        )["input_ids"]
        # The prompt is tokenized on its own only to learn how many leading
        # positions of the combined sequence must be masked out.
        prompt_len = len(
            tokenizer(
                prompt,
                add_special_tokens=True,
                return_tensors=None,
            )["input_ids"]
        )

        inputs.append(input_ids)
        labels.append([-100] * prompt_len + input_ids[prompt_len:])

    batch_enc = tokenizer.pad(
        {"input_ids": inputs},
        padding=True,
        return_attention_mask=True,
        return_tensors="pt",
    )

    max_len = batch_enc["input_ids"].size(1)
    # Labels must be padded on the same side as input_ids; otherwise a
    # left-padding tokenizer would shift every label away from its token.
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"
    labels_padded = []
    for sample_labels in labels:
        filler = [-100] * (max_len - len(sample_labels))
        labels_padded.append(
            filler + sample_labels if pad_left else sample_labels + filler
        )

    batch_enc["labels"] = torch.tensor(labels_padded, dtype=torch.long)
    batch_enc["ids"] = ids
    batch_enc["raw_context"] = raw_contexts
    batch_enc["raw_target"] = raw_targets

    return batch_enc

class GPT2WithResidualBeta(nn.Module):
    """Frozen base model plus one learnable scalar ``beta`` per transformer block.

    The wrapped model must expose its blocks as ``base_model.transformer.h``
    and each block must have an ``mlp`` sub-module. A forward hook on every
    ``block.mlp`` multiplies that module's output by the block's beta. All
    base-model parameters are frozen, so ``betas`` is the only trainable
    parameter of this module.
    """

    # Bookkeeping entries that ``collate_fn`` in this file attaches to each
    # batch. They are not model inputs and must not reach the base model's
    # forward, which may reject unknown keyword arguments.
    _METADATA_KEYS = ("ids", "raw_context", "raw_target")

    def __init__(self, base_model):
        """Freeze ``base_model``, create the betas (all 1.0) and install hooks.

        Args:
            base_model: Model exposing ``transformer.h`` (iterable of blocks,
                each with an ``mlp`` attribute).
        """
        super().__init__()
        self.base_model = base_model

        # Freeze all parameters of the base model
        for param in self.base_model.parameters():
            param.requires_grad = False

        # One beta per block, initialised to 1 so the wrapped model starts
        # out computing exactly what the base model computes.
        num_layers = len(self.base_model.transformer.h)
        self.betas = nn.Parameter(torch.ones(num_layers))

        # Hook handles are kept so the hooks can be removed later.
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach a forward hook to every block's MLP that scales its output."""
        def get_hook_fn(layer_idx):
            # Factory so each hook captures its own layer index.
            def hook_fn(module, input_tensors, output_tensors):
                # Keep the multiply valid even if this block runs on a
                # different device than the one holding ``betas``; ``.to`` is
                # a no-op when both already share a device.
                beta = self.betas[layer_idx].to(output_tensors.device)
                return output_tensors * beta
            return hook_fn

        for i, block in enumerate(self.base_model.transformer.h):
            self.hooks.append(block.mlp.register_forward_hook(get_hook_fn(i)))

    def remove_hooks(self):
        """Detach all beta hooks, restoring the base model's plain forward."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the base model (with the beta hooks active) and return its output.

        Extra keyword arguments are forwarded to the base model, except the
        batch bookkeeping keys listed in ``_METADATA_KEYS``, which are dropped.
        """
        model_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key not in self._METADATA_KEYS
        }
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **model_kwargs
        )

    def print_beta_values(self):
        """Print the current beta parameter values."""
        beta_values = self.betas.detach().cpu().numpy()
        print("Beta values:")
        for i, beta in enumerate(beta_values):
            print(f"Layer {i}: {beta:.4f}")
        print("-" * 30)

def evaluate_model(model, tokenizer, device, examples):
    """Compute average loss and first-target-token accuracy over ``examples``.

    Each example is tokenized as prompt and prompt-plus-target; labels mask
    the prompt (and padding) with -100 so the loss covers only the target
    tokens. Accuracy compares the argmax prediction at the position just
    before the first labelled token with that token.

    Side effects: re-seeds ``torch``, ``random`` and ``numpy`` with 42, moves
    ``model`` to ``device``, switches it to eval mode, and prints progress
    and a debug sample to stdout.

    Args:
        model: Callable module accepting ``input_ids``, ``attention_mask``
            and ``labels`` and returning an object with ``.loss`` and
            ``.logits`` (batch, seq_len, vocab).
        tokenizer: Callable tokenizer returning ``"input_ids"``, with
            ``pad(...)`` and ``decode(...)``. Its ``padding_side`` attribute
            (default ``"right"`` if absent) decides where labels are padded.
        device: Device on which to run the evaluation.
        examples: List of dicts with ``"prompt"``, ``"combined"`` and
            ``"target_text"``; used as given, without further processing.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` weights each batch's mean loss
        by its number of examples; both values are 0 when there is nothing
        to evaluate or score.
    """
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    model = model.to(device)
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0
    eval_batch_size = 16  # Small eval batch size

    processed_examples = examples
    print(f"Processed {len(processed_examples)} test examples")

    # Print an example for debugging
    if processed_examples:
        print("Example processed test item:")
        print(f"Prompt: {processed_examples[0]['prompt'][:100]}...")
        print(f"Target: {processed_examples[0]['target_text']}")

    # Labels must be padded on the same side as input_ids to stay aligned.
    pad_left = getattr(tokenizer, "padding_side", "right") == "left"

    with torch.no_grad():
        for batch_idx in range(0, len(processed_examples), eval_batch_size):
            batch = processed_examples[batch_idx:batch_idx + eval_batch_size]

            inputs = []
            labels = []
            for item in batch:
                token_ids = tokenizer(
                    item["combined"], add_special_tokens=True, return_tensors=None
                )["input_ids"]
                # The prompt alone tells us how many leading tokens to mask.
                prompt_len = len(
                    tokenizer(
                        item["prompt"], add_special_tokens=True, return_tensors=None
                    )["input_ids"]
                )
                inputs.append(token_ids)
                labels.append([-100] * prompt_len + token_ids[prompt_len:])

            batch_enc = tokenizer.pad(
                {"input_ids": inputs},
                padding=True,
                return_attention_mask=True,
                return_tensors="pt",
            )

            max_len = batch_enc["input_ids"].size(1)
            labels_padded = []
            for sample_labels in labels:
                filler = [-100] * (max_len - len(sample_labels))
                labels_padded.append(
                    filler + sample_labels if pad_left else sample_labels + filler
                )

            # Move tensors to device
            input_ids = batch_enc["input_ids"].to(device)
            attention_mask = batch_enc["attention_mask"].to(device)
            labels = torch.tensor(labels_padded, dtype=torch.long).to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            total_loss += outputs.loss.item() * len(batch)

            # Accuracy on the first labelled token of every row.
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)
            # argmax over a 0/1 mask yields the first labelled position
            # (or 0 when a row has no labelled token at all).
            first_label_positions = (labels != -100).float().argmax(dim=1)
            rows = torch.arange(labels.size(0), device=labels.device)
            true_tokens = labels[rows, first_label_positions]
            # A row is scorable only if it has a label and a preceding
            # position whose logits predict it.
            scorable = (first_label_positions > 0) & (true_tokens != -100)
            # clamp keeps the gather index valid for unscorable rows; their
            # result is discarded by the mask below.
            prev_positions = (first_label_positions - 1).clamp(min=0)
            pred_tokens = logits[rows, prev_positions].argmax(dim=-1)

            correct_predictions += int(((pred_tokens == true_tokens) & scorable).sum().item())
            total_samples += int(scorable.sum().item())

            # Print debug info for the first batch, but only when its first
            # row is scorable: otherwise ``pos - 1`` would wrap around to the
            # last position and show a meaningless prediction.
            if batch_idx == 0 and bool(scorable[0].item()):
                example_idx = 0
                pos = first_label_positions[example_idx].item()
                print("First batch example:")
                example_prompt = tokenizer.decode(input_ids[example_idx][:pos])
                example_true_token = tokenizer.decode(labels[example_idx, pos].unsqueeze(0))
                example_pred_token = tokenizer.decode(
                    logits[example_idx, pos - 1].argmax(dim=-1).unsqueeze(0)
                )
                print(f"Prompt:\n{example_prompt}")
                print(f"True next token: '{example_true_token}'")
                print(f"Predicted next token: '{example_pred_token}'")
                print("-" * 30)

    avg_loss = total_loss / len(processed_examples) if processed_examples else 0
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0

    print(f"Total correct: {correct_predictions}, Total samples: {total_samples}")
    print(f"Final Accuracy: {accuracy:.4f}")

    return avg_loss, accuracy

def main():
    """Train per-layer beta scalars on top of a frozen checkpoint.

    Steps: parse CLI arguments, seed RNGs, write the run configuration,
    generate synthetic training data, load the checkpoint, evaluate it,
    wrap it in ``GPT2WithResidualBeta``, train with ``Trainer`` (only the
    betas require gradients), save the betas, and evaluate again.

    Note that both evaluations run on the same synthetic examples that are
    used for training; no separate held-out set is generated here.

    Files written under ``--output_dir`` (each artifact gets its own file so
    later writes do not overwrite earlier ones):
        training_config.txt, metrics_before_training.txt, betas.pt,
        betas.txt, metrics_after_training.txt
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="path/to/your/folder")
    args = parser.parse_args()

    set_seed(args.seed)

    # Prefer GPU 7 as before, but fall back to GPU 0 on machines that have
    # fewer devices instead of failing with an invalid device ordinal.
    preferred_gpu = 7
    if torch.cuda.is_available():
        gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
        device = torch.device(f"cuda:{gpu_index}")
    else:
        device = torch.device("cpu")
    print(device)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # One distinct file per artifact, all directly inside output_dir.
    config_path = os.path.join(args.output_dir, "training_config.txt")
    initial_metrics_path = os.path.join(args.output_dir, "metrics_before_training.txt")
    betas_tensor_path = os.path.join(args.output_dir, "betas.pt")
    betas_text_path = os.path.join(args.output_dir, "betas.txt")
    final_metrics_path = os.path.join(args.output_dir, "metrics_after_training.txt")

    # Save training configuration
    with open(config_path, "w") as f:
        for arg in vars(args):
            f.write(f"{arg}: {getattr(args, arg)}\n")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    # Generate synthetic training data
    num_train_examples = 10000
    modulus = 13
    print(f"Generating {num_train_examples} synthetic training examples with K={modulus}")
    train_examples = generate_synthetic_dataset(
        num_train_examples, seq_length=10, K=modulus, noise_range=100
    )
    train_dataset = SyntheticDataset(train_examples)

    # Load model from checkpoint
    print(f"Loading model from checkpoint: {args.checkpoint_dir}")
    base_model = AutoModelForCausalLM.from_pretrained(args.checkpoint_dir)
    base_model = base_model.to(device)

    # Convert the training examples into the format evaluate_model expects.
    processed_train_examples = [
        {
            "prompt": ex["context"],
            "combined": ex["context"] + " " + ex["target"],
            "target_text": ex["target"],
        }
        for ex in train_examples
    ]

    # Baseline evaluation of the unmodified checkpoint.
    print("Evaluating model on test dataset...")
    test_loss, test_accuracy = evaluate_model(base_model, tokenizer, device, processed_train_examples)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # Save metrics
    with open(initial_metrics_path, "w") as f:
        f.write(f"Test Loss: {test_loss:.4f}\n")
        f.write(f"Test Accuracy: {test_accuracy:.4f}\n")

    # Create model with beta parameters
    model = GPT2WithResidualBeta(base_model)
    model.to(device)

    # Print initial beta values
    print("Initial beta values:")
    model.print_beta_values()

    # Training arguments
    training_args = TrainingArguments(
        seed=args.seed,
        data_seed=args.seed,
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0.01,
        # The dataset yields raw dicts that collate_fn turns into tensors.
        remove_unused_columns=False,
        report_to=["none"],
        save_strategy="no"
    )

    # Initialize trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=lambda b: collate_fn(b, tokenizer),
    )

    # Train the model (only beta parameters will be updated)
    trainer.train()

    # Print final beta values
    print("Final beta values after training:")
    model.print_beta_values()

    # Save final beta values
    final_betas = model.betas.detach().cpu()
    torch.save({"betas": final_betas}, betas_tensor_path)

    print("Training complete. Beta parameters saved.")

    # Save beta values to text file
    with open(betas_text_path, "w") as f:
        for i, beta in enumerate(final_betas.numpy()):
            f.write(f"Layer {i}: {beta:.6f}\n")

    # Evaluate the beta-augmented model on the same examples.
    print("Evaluating model on test dataset...")
    test_loss, test_accuracy = evaluate_model(model, tokenizer, device, processed_train_examples)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # Save metrics
    with open(final_metrics_path, "w") as f:
        f.write(f"Test Loss: {test_loss:.4f}\n")
        f.write(f"Test Accuracy: {test_accuracy:.4f}\n")


# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()